from .mlp import MLPDecoder, ResMLPDecoder
from .zinb import ZINBMLPDecoder, ZINBResMLPDecoder
from torch import nn

def setup_decoder(model_type, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num, standardscale=1e4) -> nn.Module:
    """Construct the decoder selected by ``model_type``.

    Args:
        model_type: Which decoder to build: ``'zinbmlp'``, ``'zinbresmlp'``,
            ``'mlp'`` or ``'resmlp'``.
        in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num:
            Forwarded unchanged, as keyword arguments, to the chosen
            decoder's constructor.
        standardscale: Forwarded only to the ZINB decoders
            (``'zinbmlp'`` / ``'zinbresmlp'``); ignored for ``'mlp'`` and
            ``'resmlp'``.

    Returns:
        The constructed decoder instance.

    Raises:
        NotImplementedError: If ``model_type`` is not a supported name.
    """
    # Arguments every decoder variant accepts; kept in one place so that all
    # four constructors are guaranteed to receive the same values.
    common_kwargs = dict(
        in_dim=in_dim,
        hidden_dim=hidden_dim,
        out_dim=out_dim,
        num_layers=num_layers,
        dropout=dropout,
        norm=norm,
        batch_num=batch_num,
    )

    if model_type == 'zinbmlp':
        return ZINBMLPDecoder(**common_kwargs, standardscale=standardscale)
    if model_type == 'zinbresmlp':
        return ZINBResMLPDecoder(**common_kwargs, standardscale=standardscale)
    if model_type == 'mlp':
        return MLPDecoder(**common_kwargs)
    if model_type == 'resmlp':
        return ResMLPDecoder(**common_kwargs)
    raise NotImplementedError(f'Unsupported model type: {model_type}')